#!/usr/bin/env python
# -*- coding: utf-8 -*-
"""
@file    test.py
@brief   Evaluate a saved model on the MNIST test set and print its accuracy.
@details
@author  Shivelino
@date    2023-12-23 19:10
@version 0.0.1

@par Copyright(c):
@par todo:
@par history:
"""
import torch
import  os
import argparse
import matplotlib.pyplot as plt
import numpy as np

from nets import get_model
from utils import get_device, get_dataloader_mnist
# Force KMP_DUPLICATE_LIB_OK to 'True' for this process.
# NOTE(review): presumably a workaround for the "duplicate OpenMP runtime"
# abort that can occur when several libraries bundle their own OpenMP —
# confirm it is still required.
os.environ.update({'KMP_DUPLICATE_LIB_OK': 'True'})


def test(opt):
    """Evaluate a saved model on the test loader and print its accuracy.

    The checkpoint is read from ``model/model_<opt.model>.pth`` (relative to
    the current working directory) and loaded into the network returned by
    ``get_model(opt.model)``.

    Args:
        opt: Options object. The attributes read here are ``model`` (network
            name, also used in the checkpoint file name), ``data_dir`` and
            ``batch_size`` (both forwarded to ``get_dataloader_mnist``).

    Returns:
        None. The accuracy is printed as a percentage; if the loader yields
        no samples a notice is printed instead.
    """
    # Only the second loader returned by the helper is needed here.
    _, testloader = get_dataloader_mnist(opt.data_dir, opt.batch_size)

    # Build the network on the selected device.
    device = get_device()
    print(f"Current Model: {opt.model}")
    model = get_model(opt.model).to(device)

    # map_location remaps the stored tensors onto the current device, so a
    # checkpoint saved on a GPU can still be loaded on a CPU-only machine.
    # NOTE(review): assumes get_device() returns a torch.device or a device
    # string — confirm against utils.get_device.
    state_dict = torch.load(f'model/model_{opt.model}.pth', map_location=device)
    model.load_state_dict(state_dict)

    # Set the model to evaluation mode.
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in testloader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # Predicted class = index of the largest value along dim 1.
            predicted = outputs.argmax(dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()

    # Guard against an empty loader, which would otherwise divide by zero.
    if total == 0:
        print('Test set is empty; accuracy is undefined.')
        return

    accuracy = correct / total
    print(f'Accuracy on the test set: {accuracy * 100:.2f}%')


if __name__ == '__main__':
    # Script entry point: declare the command-line options, parse them and
    # hand the resulting namespace to test().
    cli = argparse.ArgumentParser()
    option_table = (
        # (flag, value type, default, help text)
        ('--model', str, "lenet", 'model'),
        ('--data_dir', str, "data", 'data directory'),
        ('--batch_size', int, 128, 'size of the batches'),
    )
    for flag, value_type, fallback, description in option_table:
        cli.add_argument(flag, type=value_type, default=fallback, help=description)

    options = cli.parse_args()
    test(options)
